{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02114.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02114.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O1IOV7VPHSEo",
        "colab_type": "text"
      },
      "source": [
        "## 10.6 求近义词和类比词"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NUmqmmBoIm3z",
        "colab_type": "text"
      },
      "source": [
        "### 10.6.1 使用预训练的词向量"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ARG7Pik2JJSR",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        },
        "outputId": "e670c0c7-8bc8-4e36-881c-b6f010532947"
      },
      "source": [
        "import torch \n",
        "import torchtext.vocab as vocab \n",
        "\n",
        "# List the names of all pretrained embedding sets registered in torchtext.vocab.\n",
        "vocab.pretrained_aliases.keys()"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "dict_keys(['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d', 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d', 'glove.twitter.27B.50d', 'glove.twitter.27B.100d', 'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d', 'glove.6B.200d', 'glove.6B.300d'])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hwiEiL0AJaUR",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        },
        "outputId": "47f3c328-1fb2-4efa-e97e-80fa177c48af"
      },
      "source": [
        "# Keep only the GloVe variants among the pretrained aliases.\n",
        "[alias for alias in vocab.pretrained_aliases if 'glove' in alias]"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['glove.42B.300d',\n",
              " 'glove.840B.300d',\n",
              " 'glove.twitter.27B.25d',\n",
              " 'glove.twitter.27B.50d',\n",
              " 'glove.twitter.27B.100d',\n",
              " 'glove.twitter.27B.200d',\n",
              " 'glove.6B.50d',\n",
              " 'glove.6B.100d',\n",
              " 'glove.6B.200d',\n",
              " 'glove.6B.300d']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8zc4FEUQJqhh",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "5f0767a0-1f2b-407d-eeda-b58c15787cec"
      },
      "source": [
        "# Load the 50-dimensional vectors of the GloVe '6B' release. On the first run\n",
        "# torchtext downloads glove.6B.zip (862MB per the log below) into .vector_cache/.\n",
        "glove = vocab.GloVe(name='6B', dim=50)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            ".vector_cache/glove.6B.zip: 862MB [06:30, 2.21MB/s]                          \n",
            "100%|█████████▉| 398476/400000 [00:09<00:00, 43303.23it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7-tM7kNNKQ1j",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "1258f7f2-e0c0-48b6-e604-9a20279003cb"
      },
      "source": [
        "# Vocabulary size: number of entries in the token -> index mapping.\n",
        "len(glove.stoi)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "400000"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GWouuRY6Kd7_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "d0ec41b8-4ffa-4057-aca9-c79c77901caa"
      },
      "source": [
        "# stoi maps a token to its index; itos maps that index back to the token.\n",
        "glove.stoi['beautiful'], glove.itos[3366]"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(3366, 'beautiful')"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XCpVc3p3LKbt",
        "colab_type": "text"
      },
      "source": [
        "### 10.6.2 应用预训练词向量"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1g0lxBPtLQTs",
        "colab_type": "text"
      },
      "source": [
        "#### 1 求近义词"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GVSTr44zLTJN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def knn(W, x, k):\n",
        "    cos = torch.matmul(W, x.view((-1, ))) / (\n",
        "        (torch.sum(W * W, dim=1) + 1e-9).sqrt() * torch.sum(x * x).sqrt())\n",
        "    _, topk = torch.topk(cos, k=k)\n",
        "    topk = topk.cpu().numpy()\n",
        "    return topk, [cos[i].item() for i in topk]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9OEgU59AMBMV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_similar_tokens(query_token, k, embed):\n",
        "    \"\"\"Print the k tokens closest to query_token by cosine similarity.\n",
        "\n",
        "    Args:\n",
        "        query_token: token to look up; it must be present in embed.stoi,\n",
        "            otherwise the lookup fails.\n",
        "        k: number of neighbours to print.\n",
        "        embed: object exposing stoi (token -> index), itos (index -> token)\n",
        "            and vectors (one embedding per row).\n",
        "\n",
        "    Returns:\n",
        "        None; one line per neighbour is written to stdout.\n",
        "    \"\"\"\n",
        "    query_idx = embed.stoi[query_token]\n",
        "    # Ask for one extra hit because the query itself is among the results.\n",
        "    topk, cos = knn(embed.vectors, embed.vectors[query_idx], k + 1)\n",
        "    # Drop the query by index rather than by position: with tied or duplicate\n",
        "    # vectors it is not guaranteed to be ranked first.\n",
        "    neighbours = [(i, c) for i, c in zip(topk, cos) if i != query_idx][:k]\n",
        "    for i, c in neighbours:\n",
        "        print('cosine sim=%.3f: %s' % (c, (embed.itos[i])))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nKxHCHhlMwl_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "c2ff1aa3-0199-4947-eaa1-30c228dd4e6f"
      },
      "source": [
        "# Print the 3 nearest neighbours of 'chip' in the GloVe embedding space.\n",
        "get_similar_tokens('chip', 3, glove)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "cosine sim=0.856: chips\n",
            "cosine sim=0.749: intel\n",
            "cosine sim=0.749: electronics\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RXlVnG3yM2Jw",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "378d98fb-e32c-46ae-987e-477d6eb4d0b3"
      },
      "source": [
        "# Print the 3 nearest neighbours of 'baby' in the GloVe embedding space.\n",
        "get_similar_tokens('baby', 3, glove)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "cosine sim=0.839: babies\n",
            "cosine sim=0.800: boy\n",
            "cosine sim=0.792: girl\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wtA-M87qNCEK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "11f75e47-de2d-4223-f76c-582d34e9d8ef"
      },
      "source": [
        "# Print the 3 nearest neighbours of 'beautiful' in the GloVe embedding space.\n",
        "get_similar_tokens('beautiful', 3, glove)"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "cosine sim=0.921: lovely\n",
            "cosine sim=0.893: gorgeous\n",
            "cosine sim=0.830: wonderful\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fwhmpEehNLRf",
        "colab_type": "text"
      },
      "source": [
        "#### 2 求类比词"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_1zeihR6NSFO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_analogy(token_a, token_b, token_c, embed):\n",
        "    \"\"\"Complete the analogy token_a : token_b :: token_c : ?.\n",
        "\n",
        "    Builds the vector vec(token_b) - vec(token_a) + vec(token_c) and returns\n",
        "    the single token whose embedding is most cosine-similar to it.\n",
        "\n",
        "    Args:\n",
        "        token_a, token_b, token_c: tokens looked up through embed.stoi.\n",
        "        embed: object exposing stoi, itos and vectors.\n",
        "\n",
        "    Returns:\n",
        "        The token at the top-ranked index, read from embed.itos.\n",
        "    \"\"\"\n",
        "    vec_a, vec_b, vec_c = (\n",
        "        embed.vectors[embed.stoi[tok]] for tok in (token_a, token_b, token_c))\n",
        "    target = vec_b - vec_a + vec_c\n",
        "    nearest, _ = knn(embed.vectors, target, 1)\n",
        "    return embed.itos[nearest[0]]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "73jmMKExN0rZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "e2c8d185-49c1-4189-91c2-5c48661d620f"
      },
      "source": [
        "# man : woman :: son : ?  (nearest token to woman - man + son)\n",
        "get_analogy('man', 'woman', 'son', glove)"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'daughter'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vy5WuyH9OF_f",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "86858454-b233-484b-c89b-1ee29f9729e6"
      },
      "source": [
        "# Capital-to-country analogy: beijing : china :: tokyo : ?\n",
        "get_analogy('beijing', 'china', 'tokyo', glove)"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'japan'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IGHsXxo1ON7u",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9b5dc599-bc01-41dc-df07-1b45812aba74"
      },
      "source": [
        "# Adjective-to-superlative analogy: bad : worst :: big : ?\n",
        "get_analogy('bad', 'worst', 'big', glove)"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'biggest'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nmwDnPxKOV92",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9327b91f-0f7a-4b66-a7bd-1f39d6fc6759"
      },
      "source": [
        "# Present-to-past-tense analogy: do : did :: go : ?\n",
        "get_analogy('do', 'did', 'go', glove)"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'went'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    }
  ]
}